
feat(sa-bench): add sglang DeepSeek-V4 tokenizer#73

Merged
ishandhanani merged 5 commits into NVIDIA:main from YAMY1234:yangminl/dsv4-sglang-tokenizer-v2
Apr 24, 2026

Conversation

@YAMY1234
Collaborator

Summary

Adds a client-side tokenizer for DeepSeek-V4-Pro that matches sglang server behavior exactly, so sa-bench's input_tokens lines up with the server's #new-token.

Uses the existing module.path.ClassName hook in backend_request_func.get_tokenizer; no changes to sa-bench source are required.
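
For reference, a minimal sketch of how such a module.path.ClassName hook typically resolves the class (names are illustrative, not verbatim sa-bench source):

import importlib

def load_custom_tokenizer(spec: str, model_path: str):
    # e.g. spec = "sa_bench_tokenizers.sglang_deepseek_v4.SGLangDeepseekV4Tokenizer"
    module_name, class_name = spec.rsplit(".", 1)
    module = importlib.import_module(module_name)
    tokenizer_cls = getattr(module, class_name)
    return tokenizer_cls(model_path)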

Motivation

DeepSeek-V4 ships no Hugging Face chat template, so tokenizer.apply_chat_template() raises ValueError. sglang's server solves this internally by replacing the HF-template path with a hard-coded DSML encoder (encoding_dsv4.encode_messages) whenever arch == "DeepseekV4ForCausalLM" (see sgl-project/sglang#23600). Without a matching client-side encoder, sa-bench silently falls back to raw-text tokenization and its counts diverge from the server's — skewing ISL, TPOT, TTFT, and MTP accept-rate accounting.

Design

Two files under sa-bench/sa_bench_tokenizers/:

  • _sglang_encoding_dsv4.py — vendored byte-exact from sgl-project/sglang commit f5d03db853862c8fb0e805df591bed883a71868b. Apache-2.0 header + upstream SHA preserved, so the file can be dropped when sglang publishes an official client-side package.
  • sglang_deepseek_v4.py — HF-compatible wrapper. Its apply_chat_template mirrors serving_chat exactly (see the sketch after this list):
    1. Insert an empty system message if missing.
    2. thinking_mode="chat", reasoning_effort=None (sglang defaults).
    3. Call the vendored encode_messages(...) to render the raw DSML string.
    4. hf_tokenizer.encode(..., add_special_tokens=False) (the encoder already emits <|begin▁of▁sentence|>).
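
A minimal sketch of that wrapper, assuming the vendored encode_messages signature; not the verbatim sglang_deepseek_v4.py source:

from sa_bench_tokenizers._sglang_encoding_dsv4 import encode_messages

class SGLangDeepseekV4Tokenizer:
    def __init__(self, hf_tokenizer):
        self.hf_tokenizer = hf_tokenizer

    def apply_chat_template(self, messages, **kwargs):
        # 1) serving_chat always sees a leading system turn
        if not messages or messages[0]["role"] != "system":
            messages = [{"role": "system", "content": ""}] + messages
        # 2)-3) render the raw DSML string with sglang's defaults
        prompt = encode_messages(messages, thinking_mode="chat", reasoning_effort=None)
        # 4) the encoder already emits <|begin▁of▁sentence|>, so HF must
        #    not add special tokens a second time
        return self.hf_tokenizer.encode(prompt, add_special_tokens=False)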

Naming: the package is sa_bench_tokenizers (not tokenizers) to avoid shadowing the HuggingFace tokenizers top-level package when sa-bench is run with transformers/vllm installed.

Usage (recipe)

benchmark:
  type: "sa-bench"
  custom_tokenizer: "sa_bench_tokenizers.sglang_deepseek_v4.SGLangDeepseekV4Tokenizer"
  use_chat_template: true
  ...

Files

  • src/srtctl/benchmarks/scripts/sa-bench/sa_bench_tokenizers/__init__.py
  • src/srtctl/benchmarks/scripts/sa-bench/sa_bench_tokenizers/_sglang_encoding_dsv4.py (vendored, Apache-2.0)
  • src/srtctl/benchmarks/scripts/sa-bench/sa_bench_tokenizers/sglang_deepseek_v4.py

Relationship to other PRs

Recipe YAML authoring is out of scope for this PR; see NVIDIA#70 for the DeepSeek-V4-Pro sglang recipes.

Test plan

  • python -m py_compile passes on all three files.
  • Offline equivalence check: SGLangDeepseekV4Tokenizer(...).apply_chat_template(messages) token IDs match server-side tokenizer.encode(encode_messages(messages)) for representative prompts (see the sketch after this list).
  • Online smoke: short sa-bench run against a DeepSeek-V4-Pro sglang recipe, asserting client input_tokens == server #new-token.
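
A hedged sketch of that offline equivalence check, assuming the wrapper constructor from the Design sketch above; the model path and prompt are illustrative:

from transformers import PreTrainedTokenizerFast

from sa_bench_tokenizers._sglang_encoding_dsv4 import encode_messages
from sa_bench_tokenizers.sglang_deepseek_v4 import SGLangDeepseekV4Tokenizer

hf = PreTrainedTokenizerFast(tokenizer_file="/models/DeepSeek-V4-Pro/tokenizer.json")
wrapper = SGLangDeepseekV4Tokenizer(hf)  # constructor shape assumed

messages = [{"role": "user", "content": "What is 2 + 2?"}]
client_ids = wrapper.apply_chat_template(messages)

# Simulated server path: DSML render, then plain BPE encode.
server_msgs = [{"role": "system", "content": ""}] + messages
server_ids = hf.encode(
    encode_messages(server_msgs, thinking_mode="chat", reasoning_effort=None),
    add_special_tokens=False,
)

assert client_ids == server_ids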

Adds a client-side tokenizer for DeepSeek-V4-Pro that matches sglang
server behavior. Usable via the existing 'module.path.ClassName' hook
in backend_request_func.get_tokenizer; no changes to sa-bench itself.

Motivation: DeepSeek-V4 ships no HF chat template, so
tokenizer.apply_chat_template() raises ValueError. Sglang's server
replaces the HF path with a hard-coded DSML encoder
(encoding_dsv4.encode_messages) whenever arch == 'DeepseekV4ForCausalLM',
per sgl-project/sglang PR #23600. Without a matching client-side
encoder, sa-bench input_tokens diverges from server #new-token.

Implementation:
  - sa_bench_tokenizers/_sglang_encoding_dsv4.py: vendored byte-exact
    from sgl-project/sglang@f5d03db (Apache-2.0, 840 lines).
  - sa_bench_tokenizers/sglang_deepseek_v4.py: HF-compatible wrapper.
    apply_chat_template() mirrors serving_chat exactly:
      1) insert empty system message if missing
      2) thinking_mode='chat', reasoning_effort=None (defaults)
      3) call encode_messages(...)
      4) hf_tokenizer.encode(..., add_special_tokens=False)

Usage (recipe):
  benchmark:
    custom_tokenizer: "sa_bench_tokenizers.sglang_deepseek_v4.SGLangDeepseekV4Tokenizer"

Recipes: recipe YAML authoring is out of scope; see NVIDIA#70
for DeepSeek-V4-Pro sglang recipes.
@YAMY1234 YAMY1234 marked this pull request as draft April 24, 2026 17:26
@codecov-commenter

Codecov Report

✅ All modified and coverable lines are covered by tests.
⚠️ Please upload report for BASE (main@54badf2). Learn more about missing BASE report.

Additional details and impacted files
@@           Coverage Diff           @@
##             main      #73   +/-   ##
=======================================
  Coverage        ?   70.35%           
=======================================
  Files           ?       59           
  Lines           ?     6270           
  Branches        ?        0           
=======================================
  Hits            ?     4411           
  Misses          ?     1859           
  Partials        ?        0           


AutoTokenizer.from_pretrained rejects checkpoints whose model_type
(deepseek_v4) is not yet registered in mainline transformers. The V4
checkpoint ships a ready-made tokenizer.json, so prefer loading it
directly through PreTrainedTokenizerFast and only fall back to
AutoTokenizer (for future transformers releases that register V4).
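
A minimal sketch of that loading order (names and layout assumed; not the verbatim wrapper code):

import os
from transformers import AutoTokenizer, PreTrainedTokenizerFast

def load_dsv4_hf_tokenizer(model_path: str):
    tokenizer_json = os.path.join(model_path, "tokenizer.json")
    if os.path.exists(tokenizer_json):
        # model_type "deepseek_v4" is not registered in mainline
        # transformers yet, so load the fast-tokenizer file directly.
        return PreTrainedTokenizerFast(tokenizer_file=tokenizer_json)
    # Fallback for future transformers releases that register V4.
    return AutoTokenizer.from_pretrained(model_path)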

Verified offline on DeepSeek-V4-Pro: 4 representative prompts
(hello, GSM8K, system+user, multi-turn) each produce client token IDs
byte-identical to the sglang server path
tokenizer.encode(encode_messages(msgs_with_empty_system,
thinking_mode=chat)).
@YAMY1234 YAMY1234 force-pushed the yangminl/dsv4-sglang-tokenizer-v2 branch from 163b864 to 1376701 Compare April 24, 2026 17:44
…V4 tokenizer

Minimal recipe that exercises the new SGLangDeepseekV4Tokenizer wrapper end
to end on a single GB300 node (TP=4, MTP 3/4, MXFP4 MoE). Used to verify
client-side prompt encoding aligns with the SGLang server for DeepSeek-V4.

Key knobs:
- mem-fraction-static: 0.82 (required on 1 node: DSv4-Pro weights occupy
  ~206 GB/GPU, so a lower mfs leaves a negative KV pool once SGLang reserves
  (1-mfs)*total_mem for activations, triggering "Not enough memory"; see the
  back-of-envelope check after this list)
- use_chat_template: true
- custom_tokenizer: sa_bench_tokenizers.sglang_deepseek_v4.SGLangDeepseekV4Tokenizer
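
Back-of-envelope check of that constraint (the ~288 GB HBM per GB300 GPU figure is an assumption; the 206 GB/GPU weights figure is from the commit message):

# SGLang reserves (1 - mfs) * total_mem for activations, so the KV pool
# is roughly mfs * total_mem - weights.
total_mem_gb = 288.0  # assumed HBM per GB300 GPU
weights_gb = 206.0    # DSv4-Pro weights per GPU (from the commit message)

def kv_pool_gb(mfs: float) -> float:
    return mfs * total_mem_gb - weights_gb

print(kv_pool_gb(0.82))  # ~30 GB of KV cache -> boots
print(kv_pool_gb(0.70))  # ~-4 GB -> "Not enough memory"
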
@YAMY1234
Collaborator Author

End-to-end evidence that the new tokenizer wires up correctly

Smoke run: gb300 1-node TP=4, DeepSeek-V4-Pro, MTP 3/4 (speculative-num-steps=3, speculative-num-draft-tokens=4), EAGLE topk=1, random 1k/1k, recipe = recipes/gb300-fp4/1k1k-dsv4/agg-low-latency-chat.yaml.

1. Client actually loads our wrapper and enables apply_chat_template

Taken from the main benchmark run (num_prompts=10, save_result=True) at concurrency=1, where sa-bench/benchmark_serving.py prints its parsed Namespace:

Namespace(backend='dynamo', ... endpoint='/v1/completions',
  num_prompts=10, save_result=True, result_filename='results_concurrency_1_gpus_4.json',
  random_input_len=1024, random_output_len=1024, random_range_ratio=0.8,
  use_chat_template=True,
  custom_tokenizer='sa_bench_tokenizers.sglang_deepseek_v4.SGLangDeepseekV4Tokenizer',
  ...)

Both flags we care about are set: use_chat_template=True, and custom_tokenizer points at the wrapper this PR adds. No modifications to backend_request_func.py — the existing module.path.ClassName hook just works.

2. sglang server accepts the DSML-rendered prompts and MTP is live

From theia0259_agg_w0.out (server-side):

[2026-04-24 11:13:49 TP0] Prefill batch, #new-seq: 1, #new-token: 1024, ...
[2026-04-24 11:13:49 TP0] Decode batch, #running-req: 1, accept len: 3.00, accept rate: 0.75,
                                        cuda graph: True, gen throughput (token/s): 144.37, ...

Prefill reports #new-token: 1024 for an isl=1024 request — i.e. the server BPE-encodes our DSML string to the same count sa-bench assumed client-side, so per-token metrics (TPOT, TTFT, output throughput) are on a consistent footing.

Full disclosure: the container ships SGLang from before sgl-project/sglang#23600 was merged, so its own chat-template lookup still prints No HuggingFace chat template found for /v1/chat/completions. That's exactly why sa-bench needs the client-side DSML renderer this PR adds — we run over /v1/completions with an already-DSML-rendered prompt string, and the server's fast tokenizer resolves the DeepSeek-V4 special tokens from tokenizer.json into single-token IDs. This is functionally equivalent to what the server-side encoder will do once #23600 lands.

3. MTP accept-length, with vs. without chat template

Pulled from decode_log_interval=1 lines in the _agg_w0.out worker logs, partitioned into concurrency phases by the benchmark client's timestamps.

With this PR (use_chat_template=true, custom tokenizer) — gb300 1-node, isl=osl=1024, job 1573123:

concurrency   decode steps   avg accept_len   avg accept_rate
          1           3946            2.610             0.652
          2           3861            2.771             0.693
          4           3983            2.637             0.659
          8           4536            2.650             0.662

Without chat template (prior baseline, use_chat_template=false) — same HW/recipe family, gb300 1-node, isl=osl=1024, reference job 1567633 (official agg-low-latency-1k1k):

concurrency   decode steps   avg accept_len   avg accept_rate
          1           6499            3.338             0.835
          2           3318            3.491             0.873
          4           1912            3.439             0.861
          8           1316            3.127             0.782
         16           1018            3.162             0.791
         32            986            2.996             0.749
         64           1290            3.049             0.762
        128           1558            3.192             0.798

Without chat template, sa-bench feeds the model a raw tokenizer.decode(random_ids) prompt. That is easy for the draft model to extrapolate, so accept_len runs 3.0–3.5. Once the prompt is wrapped in DSML (<|begin▁of▁sentence|>, <|User|>, turn boundaries, etc.) the draft model sees a realistic prefix distribution and accept_len settles around 2.6–2.8 — that is the number you should report for DeepSeek-V4-Pro MTP on 1k/1k chat traffic. The pre-PR 3.0–3.5 overstated it.

4. Offline equivalence check (what #23600 would produce)

SGLangDeepseekV4Tokenizer(...).apply_chat_template(messages) token IDs are byte-identical to the simulated server path encoding_dsv4.encode_messages(messages) → hf_tokenizer.encode(..., add_special_tokens=False) on 4 representative prompts (gsm8k-style few-shot, multi-turn chat, tool-call, empty-system). Same HF fast tokenizer from tokenizer.json on both sides.

@YAMY1234 YAMY1234 marked this pull request as ready for review April 24, 2026 21:04
Required by NVIDIA/srt-slurm CI license check.
…RT env fallback

The client-side DSV4 tokenizer renders prompts via the vendored
``encoding_dsv4.encode_messages`` and previously hard-coded
``thinking_mode="chat"`` and ``reasoning_effort=None``. This matches the
sglang server only when the user has not set the DSV4 thinking knobs.

sglang ``serving_chat.py`` (PR #23600) honors two envs as fallbacks:

- ``SGLANG_ENABLE_THINKING=1`` -> ``thinking_mode="thinking"``
- ``SGLANG_REASONING_EFFORT=max|high`` -> passed through to the encoder

Recipes typically set these in ``prefill_environment`` /
``decode_environment`` for reasoning eval workloads (gpqa, aime25, etc.)
on DeepSeek-V4-Flash. Without matching fallback on the sa-bench client,
server prompts would be wrapped in ``<think>...</think>`` while the
client still rendered the chat template, desynchronizing ISL / TPOT /
MTP accept-rate accounting.

This change:

- Adds ``_env_enable_thinking()`` / ``_env_reasoning_effort()`` helpers
  that parse the envs the same way sglang ``EnvBool`` / ``EnvStr`` do
  (``{1,true,yes,on}`` truthy; ``max|high`` filter).
- Changes ``apply_chat_template`` defaults from Python literals to
  ``None`` sentinels; when ``None``, falls back to env. Explicit caller
  kwargs (incl. ``thinking=False``) still win, matching the server's
  ``(request.chat_template_kwargs or {}).get("thinking", env_default)``
  precedence.
- Expands the docstring to show the real server call chain (not just
  the happy-path defaults).

Smoke-tested all five precedence cases (no env + no kwarg / env on +
no kwarg / env on + explicit False / env off + explicit True / bogus
env value filtered).
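
A minimal sketch of the env-fallback precedence described above (helper names follow the commit message; parsing details are assumptions):

import os

_TRUTHY = {"1", "true", "yes", "on"}

def _env_enable_thinking() -> bool:
    # Mirrors sglang's EnvBool truthiness: {1, true, yes, on}.
    return os.environ.get("SGLANG_ENABLE_THINKING", "").strip().lower() in _TRUTHY

def _env_reasoning_effort():
    # Only "max" / "high" pass through; bogus values are filtered.
    val = os.environ.get("SGLANG_REASONING_EFFORT", "").strip().lower()
    return val if val in ("max", "high") else None

def resolve_thinking_mode(thinking=None):
    # None sentinel falls back to the env; explicit kwargs (including
    # thinking=False) win, matching the server's
    # (request.chat_template_kwargs or {}).get("thinking", env_default).
    if thinking is None:
        thinking = _env_enable_thinking()
    return "thinking" if thinking else "chat"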

Made-with: Cursor
@ishandhanani ishandhanani merged commit 99d6195 into NVIDIA:main Apr 24, 2026
6 checks passed
ishandhanani pushed a commit that referenced this pull request Apr 26, 2026
… lacks chat_template (#76)

Two related fixes for sa-bench when running models without a jinja chat
template (e.g. DeepSeek-V4-Pro):

1. benchmark_serving.py: when --use-chat-template is set but the loaded
   tokenizer has neither a jinja chat_template nor an overridden
   apply_chat_template method, fail fast with a clear message pointing
   to either SGLangDeepseekV4Tokenizer (#73) or use_chat_template: false
   (sketched after this list). Previously this crashed deep inside
   transformers with a generic ValueError that gave no hint how to fix
   the recipe.

2. bench.sh: warmup runs were missing CHAT_TEMPLATE_ARGS, so warmup
   always ran without chat template even when the main run had it
   enabled -- leading to mismatched cache state between warmup and
   measurement. Also adds an early-exit notice when use_chat_template
   is true but no custom_tokenizer is configured.
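
A hedged sketch of that fail-fast check (names assumed; not verbatim benchmark_serving.py):

from transformers import PreTrainedTokenizerBase

def check_chat_template(tokenizer, use_chat_template: bool) -> None:
    has_jinja = getattr(tokenizer, "chat_template", None) is not None
    # Custom wrappers (e.g. SGLangDeepseekV4Tokenizer) override
    # apply_chat_template; plain HF tokenizers inherit the base method.
    base = getattr(PreTrainedTokenizerBase, "apply_chat_template", None)
    overridden = getattr(type(tokenizer), "apply_chat_template", None) is not base
    if use_chat_template and not (has_jinja or overridden):
        raise SystemExit(
            "--use-chat-template is set but the tokenizer has no chat template; "
            "point custom_tokenizer at SGLangDeepseekV4Tokenizer (#73) or set "
            "use_chat_template: false."
        )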
